{% extends 'base.attached.class' %}

{% block content %}
var {{ class_name }} = function(hidden, output, layers, weights, bias) {
    this.hidden = hidden.toUpperCase();
    this.output = output.toUpperCase();
    this.network = new Array(layers.length + 1);
    for (var i = 0, l = layers.length; i < l; i++) {
        this.network[i + 1] = new Array(layers[i]).fill(0.);
    }
    this.weights = weights;
    this.bias = bias;

    var findMax = function(nums) {
        var i, l = nums.length, idx = 0;
        for (i = 0; i < l; i++) {
            idx = nums[i] > nums[idx] ? i : idx;
        }
        return idx;
    };

    var compute = function(activation, v) {
        var i, l = v.length;
        switch (activation) {
            case 'LOGISTIC':
                for (i = 0; i < l; i++) {
                    v[i] = 1. / (1. + Math.exp(-v[i]));
                }
                break;
            case 'RELU':
                for (i = 0; i < l; i++) {
                    v[i] = Math.max(0, v[i]);
                }
                break;
            case 'TANH':
                for (i = 0; i < l; i++) {
                    v[i] = Math.tanh(v[i]);
                }
                break;
            case 'SOFTMAX':
                var max = Number.NEGATIVE_INFINITY;
                for (i = 0; i < l; i++) {
                    if (v[i] > max) {
                        max = v[i];
                    }
                }
                for (i = 0; i < l; i++) {
                    v[i] = Math.exp(v[i] - max);
                }
                var sum = 0.0;
                for (i = 0; i < l; i++) {
                    sum += v[i];
                }
                for (i = 0; i < l; i++) {
                    v[i] /= sum;
                }
                break;
        }
        return v;
    };

    this.resetNetwork = function() {
        for (var i = 1, l = this.network.length - 1; i < l; i++) {
            for (var j = 0; j < this.network[i].length; j++) {
                this.network[i][j] = 0;
            }
        }
    };

    this.feedForward = function(neurons) {
        this.network[0] = neurons;
        for (var i = 0; i < this.network.length - 1; i++) {
            for (var j = 0; j < this.network[i + 1].length; j++) {
                this.network[i + 1][j] = this.bias[i][j];
                for (var l = 0; l < this.network[i].length; l++) {
                    this.network[i + 1][j] += this.network[i][l] * this.weights[i][l][j];
                }
            }
            if ((i + 1) < (this.network.length - 1)) {
                this.network[i + 1] = compute(this.hidden, this.network[i + 1]);
            }
        }
        this.network[this.network.length - 1] = compute(this.output, this.network[this.network.length - 1]);
        this.resetNetwork();
        return this.network[this.network.length - 1];
    };

    this.predict = function(neurons) {
        var lastLayer = this.feedForward(neurons);
        if (lastLayer.length === 1) {
            if (lastLayer[0] > .5) {
                return 1;
            }
            return 0;
        }
        return findMax(lastLayer);
    };

    this.predictProba = function(neurons) {
        var lastLayer = this.feedForward(neurons);
        if (lastLayer.length === 1) {
            return [lastLayer[0], 1 - lastLayer[0]];
        }
        return lastLayer;
    };

};

var main = function () {
    if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {
        if (process.argv.length - 2 !== {{ n_features }}) {
            var IllegalArgumentException = function(message) {
                this.message = message;
                this.name = "IllegalArgumentException";
            }
            throw new IllegalArgumentException("You have to pass {{ n_features }} features.");
        }
    }

    // Features:
    var features = process.argv.slice(2);
    for (var i = 0; i < features.length; i++) {
        features[i] = parseFloat(features[i]);
    }

    // Model data:
    {{ layers }}
    {{ weights }}
    {{ bias }}

    // Estimator:
    var clf = new {{ class_name }}('{{ hidden_activation }}', '{{ output_activation }}', layers, weights, bias);

    {% if is_test or to_json %}
    // Get JSON:
    console.log(JSON.stringify({
        "predict": clf.predict(features),
        "predict_proba": clf.predictProba(features)
    }));
    {% else %}
    // Get class prediction:
    var prediction = clf.predict(features);
    console.log("Predicted class: #" + prediction);

    // Get class probabilities:
    var probabilities = clf.predictProba(features);
    for (var i = 0; i < probabilities.length; i++) {
        console.log("Probability of class #" + i + " : " + probabilities[i]);
    }
    {% endif %}
}

// Run the command-line entry point only when this file is executed directly
// by Node.js. The typeof checks keep the script loadable where the CommonJS
// globals do not exist (for example a browser script tag), instead of
// failing with a ReferenceError.
if (typeof require !== 'undefined' && typeof module !== 'undefined' && require.main === module) {
    main();
}
{% endblock %}